import datetime as dt
import re
import traceback
from functools import wraps

import connectorx as cx
import pandas as pd
from croniter import croniter
from pandas._libs.lib import infer_dtype
from pangres import upsert, aupsert
from sqlalchemy import create_engine, VARCHAR, TIMESTAMP, Float, Integer, BIGINT, text
from sqlalchemy.ext.asyncio import create_async_engine

from hysdata import utils
from hysdata.models import Query
from hysdata.utils import load_config, store_config


def table_exists(table_name: str, con):
    """
    Check whether a table already exists in the ``public`` schema.

    :param table_name: bare table name (no schema prefix)
    :param con: SQLAlchemy engine/connection accepted by ``pandas.read_sql``
    :return: True if the table exists, False otherwise
    """
    # Bind the qualified name as a parameter instead of interpolating it into
    # the SQL text; to_regclass() returns NULL for an unknown relation.
    df = pd.read_sql(
        text("SELECT to_regclass(:qualified);"), con,
        params={"qualified": "public." + table_name})
    return df['to_regclass'].iloc[0] is not None


async def async_table_exists(table_name: str, async_engine):
    """
    Async check whether a table already exists in the ``public`` schema.

    :param table_name: bare table name (no schema prefix)
    :param async_engine: SQLAlchemy async engine
    :return: True if the table exists, False otherwise
    """
    async with async_engine.connect() as connection:
        result = await connection.execute(
            text("SELECT to_regclass(:qualified);"),
            {"qualified": "public." + table_name})
        # to_regclass() yields NULL when the relation does not exist
        return result.scalar() is not None

def add_id(table_name, engine):
    """
    Add an auto-increment ``id`` column to the table.

    The serial id speeds up connector-x partitioned reads; the table's unique
    index (created by pangres) deduplicates repeated inserts.
    :param table_name: table name in the ``public`` schema
    :param engine: SQLAlchemy (sync) engine
    :return: the DDL execution result
    """
    # engine.begin() commits on success; a plain connect() would leave the DDL
    # uncommitted under SQLAlchemy 2.x (the async twin commits explicitly).
    with engine.begin() as connection:
        result = connection.execute(
            text("ALTER TABLE public.{table} ADD COLUMN  IF NOT EXISTS  id serial;".format(table=table_name)))
        return result


async def async_add_id(table_name, async_engine):
    """Async variant of :func:`add_id`: add a serial ``id`` column to the table."""
    ddl = text(f"ALTER TABLE public.{table_name} ADD COLUMN  IF NOT EXISTS  id serial;")
    async with async_engine.connect() as conn:
        await conn.execute(ddl)
        await conn.commit()


def normalize_code(codes) -> list:
    """Extract the 6-digit security code from each entry and append its
    exchange suffix (.SZ / .BJ / .SH) based on the leading digit."""
    pattern = re.compile(r"\d{6}")
    normalized = []
    for raw in codes:
        num = pattern.search(raw).group(0)
        if num[0] in "013":
            suffix = ".SZ"
        elif num[0] in "84":
            suffix = ".BJ"
        else:
            suffix = ".SH"
        normalized.append(num + suffix)
    return normalized


def normal_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Auto-convert date and security-code columns to canonical formats.

    Date-like columns (known names, or names containing '日期') are parsed to
    datetime; code-like columns (known names, or names containing '代码') are
    zero-padded to 6 digits and given an exchange suffix via normalize_code.
    :param df: input frame; not modified, a converted copy is returned
    :return: converted copy of df
    """
    df_ret = df.copy()
    for i, j in zip(df.columns, df.dtypes):
        if str(i) in ['公告日期', '实际公告日期', '报告期', 'date', 'updated', 'time', '更新时间'] or '日期' in str(i):
            first = df_ret[i].iloc[0] if len(df_ret[i]) > 0 else None
            # 5-char strings like '01-02' lack a year: prefix the current one.
            # isinstance guard: an already-datetime column would crash len().
            if isinstance(first, str) and len(first) == 5:
                df_ret[i] = dt.datetime.now().strftime("%Y") + "-" + df_ret[i]
            df_ret[i] = pd.to_datetime(df_ret[i])
        elif str(i) in ['code', '证券代码'] or '代码' in str(i):
            if str(j) != 'object':
                # numeric codes lose leading zeros; restore them before suffixing
                df_ret[i] = df_ret[i].astype(str).str.zfill(6)
            df_ret[i] = normalize_code(df_ret[str(i)].tolist())
    return df_ret


def rename_col_names(cols) -> list:
    """Sanitize column names: spaces become underscores, parentheses become '$'."""
    translation = str.maketrans({" ": "_", "(": "$", ")": "$"})
    return [name.translate(translation) for name in cols]


def dtypes_normal(df: pd.DataFrame) -> dict:
    """
    Build a SQLAlchemy column-type mapping from a DataFrame's columns and
    index levels, suitable as the ``dtype`` argument to pangres' upsert.

    :param df: frame whose columns and index levels are inspected
    :return: dict of {column/index-level name: SQLAlchemy type instance}
    """
    type_dict = {}
    # A MultiIndex reports inferred_type == 'mixed' and exposes .dtypes;
    # a plain Index contributes its single inferred type instead.
    index_types = [d for d in df.index.dtypes] if df.index.inferred_type == 'mixed' else [df.index.inferred_type]
    for i, j in zip(list(df.columns) + list(df.index.names), list(df.dtypes) + index_types):
        # i may name either a column or an index level; branch accordingly.
        if i in df.columns:
            dtype = infer_dtype(df[i].values, skipna=True)
        else:
            dtype = infer_dtype(df.index.get_level_values(i), skipna=True)
        if 'date' in dtype or 'time' in dtype:
            type_dict.update({i: TIMESTAMP()})
        elif "string" == dtype:
            if i in df.columns:
                len_series = df[i].str.len()
            else:
                len_series = df.index.get_level_values(i).str.len()
            # VARCHAR sized to double the longest observed value, floor 255,
            # to leave headroom for longer future rows.
            type_dict.update({i: VARCHAR(length=max(255, len_series.max(skipna=True) * 2))})
        elif "float" in dtype or 'decimal' in dtype:
            type_dict.update({i: Float(precision=4, asdecimal=True)})
        # Integer widths come from the pandas dtype (j), not infer_dtype.
        elif "int32" in str(j):
            type_dict.update({i: Integer()})
        elif "int64" in str(j):
            type_dict.update({i: BIGINT()})
    return type_dict


class Store:
    """Persist DataFrames into PostgreSQL with upsert semantics, either
    directly (``save_df``/``async_save_df``) or via decorators that wrap a
    DataFrame-returning function (``save``/``async_save``), plus query helpers."""

    def __init__(self, con_url: str = None, debug: bool = True, print_log: bool = True):
        """
        Data persistence helper.
        :param con_url: database connection url; when empty, falls back to the
            stored config. When given, it is cached back into the config.
        :param debug: when True the crontab expression is ignored and wrapped
            functions always execute
        :param print_log: when True execution progress is printed
        Example:
        store = Store(con_url="postgresql+psycopg2://username:password@host:port/db", debug=False)

        store.save(unique_cols=['hs_code'])
        def stock_list(client):
            return client.sync_call_query(api.stock_list())

        if __name__ == "__main__":
            client = HSClient()
            stock_list(client)
        """
        config = load_config()
        if con_url is None or len(con_url) < 1:
            self.con_url = config.get("con_url")
        else:
            self.con_url = con_url
            config['con_url'] = con_url
            store_config(config)
        self.debug = debug
        self.print_log = print_log
        self.engine = create_engine(self.con_url)
        # The async engine reuses the same url with the asyncpg driver.
        self.async_engine = create_async_engine(self.con_url.replace("psycopg2", 'asyncpg'))

    @staticmethod
    def _prepare_df(df: pd.DataFrame, unique_cols: list):
        """Shared pre-save step: compute the upsert chunk size, move the unique
        columns into the index, sanitize column names, and stamp updated_at.
        Returns (prepared_df, chunksize)."""
        # Chunk size is computed BEFORE set_index so unique_cols still count
        # as columns (preserves the historical sizing formula).
        chunksize = int(32766 / df.shape[1] / 2)
        df = df.set_index(unique_cols, drop=True)
        # Strip characters that are awkward in PostgreSQL identifiers.
        df.rename(
            columns={col: col.replace("%", "百分比").replace("(", "").replace(")", "").replace("（",
                                                                                            '').replace(
                "）", "") for col in
                df.columns}, inplace=True)
        df['updated_at'] = dt.datetime.now()
        return df, chunksize

    def save_df(self, df: pd.DataFrame, table_name: str = None, unique_cols: list = None):
        """
        Upsert a DataFrame into *table_name*.
        :param df: DataFrame to save
        :param table_name: target table name
        :param unique_cols: unique-index columns (become the frame's index)
        """
        # None default avoids the shared-mutable-default pitfall; [] keeps
        # the original behavior when the caller omits the argument.
        unique_cols = [] if unique_cols is None else unique_cols
        exists = table_exists(table_name, self.engine)
        if df is not None and df.shape[0] > 0:
            df, chunksize = self._prepare_df(df, unique_cols)
            upsert(con=self.engine,
                   df=df,
                   table_name=table_name,
                   if_row_exists='update',
                   dtype=dtypes_normal(df), chunksize=chunksize)
            if not exists:
                # First creation: add the serial id used for partitioned reads.
                add_id(table_name, self.engine)

    async def async_save_df(self, df: pd.DataFrame, table_name: str = None, unique_cols: list = None):
        """
        Async upsert of a DataFrame into *table_name*.
        :param df: DataFrame to save
        :param table_name: target table name
        :param unique_cols: unique-index columns (become the frame's index)
        """
        unique_cols = [] if unique_cols is None else unique_cols
        exists = await async_table_exists(table_name, self.async_engine)
        if df is not None and df.shape[0] > 0:
            df, chunksize = self._prepare_df(df, unique_cols)
            await aupsert(con=self.async_engine,
                          df=df,
                          table_name=table_name,
                          if_row_exists='update',
                          dtype=dtypes_normal(df), chunksize=chunksize)
            if not exists:
                await async_add_id(table_name, self.async_engine)

    def save(self, table_name: str = None, unique_cols: list = None, crontab: str = "* * * * *"):
        """
        Decorator for a DataFrame-returning function: saves its result.
        :param table_name: target table; defaults to the function's name
        :param unique_cols: unique-index columns
        :param crontab: crontab expression; when debug=False the wrapped
            function only runs if the current time matches it
        :return: decorator
        """
        unique_cols = [] if unique_cols is None else unique_cols
        debug = self.debug
        print_log = self.print_log

        def _memoize(fn):
            @wraps(fn)  # preserve the wrapped function's metadata
            def __memoize(*args, **kwargs):
                try:
                    if croniter.match(crontab, dt.datetime.now()) or debug:
                        table = table_name if table_name is not None else fn.__name__
                        exists = table_exists(table, self.engine)
                        if print_log:
                            print("------------------------------------------------------------------------------")
                            print(table + " start: " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                            print("位置参数:", args, "命名参数:", kwargs)
                        df = fn(*args, **kwargs)
                        if df is not None and df.shape[0] > 0:
                            df, chunksize = self._prepare_df(df, unique_cols)
                            upsert(con=self.engine,
                                   df=df,
                                   table_name=table,
                                   schema='public',
                                   if_row_exists='update',
                                   dtype=dtypes_normal(df), chunksize=chunksize)
                            if not exists:
                                add_id(table, self.engine)
                        if print_log:
                            print(table + "   end: " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                        return df
                except Exception as e:
                    # Best-effort: a failing job is logged, never fatal,
                    # so a scheduler running many jobs keeps going.
                    print(e)
                    traceback.print_exc()

            return __memoize

        return _memoize

    def async_save(self, table_name: str = None, unique_cols: list = None, crontab: str = "* * * * *"):
        """Async counterpart of :meth:`save`: decorates an async
        DataFrame-returning function and upserts its result."""
        unique_cols = [] if unique_cols is None else unique_cols
        debug = self.debug
        print_log = self.print_log

        def _memoize(fn):
            @wraps(fn)
            async def __memoize(*args, **kwargs):
                try:
                    if croniter.match(crontab, dt.datetime.now()) or debug:
                        table = table_name if table_name is not None else fn.__name__
                        exists = await async_table_exists(table, self.async_engine)
                        if print_log:
                            print("------------------------------------------------------------------------------")
                            print(table + " start: " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                            print("位置参数:", args, "命名参数:", kwargs)
                        df = await fn(*args, **kwargs)
                        if df is not None and df.shape[0] > 0:
                            df, chunksize = self._prepare_df(df, unique_cols)
                            await aupsert(con=self.async_engine,
                                          df=df,
                                          table_name=table,
                                          schema='public',
                                          if_row_exists='update',
                                          dtype=dtypes_normal(df), chunksize=chunksize)
                            if not exists:
                                await async_add_id(table, self.async_engine)
                        if print_log:
                            print(table + "   end: " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                        return df
                except Exception as e:
                    print(e)
                    traceback.print_exc()

            return __memoize

        return _memoize

    def sync_query(self, query: Query, cn_column_names=False, connectorx=True):
        """
        Load the table backing an API Query object as a DataFrame. Some APIs
        store data in a shape different from the query result, so this only
        applies to part of the API surface.
        :param query: e.g. api.stock_list()
        :param cn_column_names: return Chinese column names when True
        :param connectorx: load via connectorx (faster) instead of pandas.read_sql
        :return: DataFrame
        """
        table = query.method
        sql = "select * from {table}".format(table=table)
        if len(query.params) > 0:
            # NOTE(review): param values are interpolated directly into the SQL
            # (connectorx only accepts raw SQL strings). This assumes query.params
            # come from trusted API definitions, not end-user input — confirm.
            sql = sql + " where " + (" and ".join([k + "=" + "'" + v + "'" for k, v in query.params.items()]))
        if connectorx:
            df = cx.read_sql(self.con_url.replace("+psycopg2:", ":"), sql)
        else:
            df = pd.read_sql(sql, self.engine)
        if df.shape[0] > 0:
            return utils.change_type_and_column_names(df, query, cn_column_names, change_types=False)
        else:
            return df

    def sync_sql(self, sql, cn_column_names=False, query: Query = None, connectorx=True):
        """
        Execute a raw SQL query against the database.
        :param sql: raw SQL string
        :param cn_column_names: return Chinese column names when True; requires
            *query* to supply the name mapping
        :param query: e.g. api.stock_list()
        :param connectorx: load via connectorx (faster) instead of pandas.read_sql
        :return: DataFrame
        """
        if connectorx:
            df = cx.read_sql(self.con_url.replace("+psycopg2:", ":"), sql)
        else:
            df = pd.read_sql(sql, self.engine)
        if df.shape[0] > 0 and cn_column_names and query is not None:
            return utils.change_type_and_column_names(df, query, cn_column_names, change_types=False)
        else:
            return df

__all__ = ["Store", "normalize_code", "normal_cols", "dtypes_normal"]
